package redisstore

import (
	"context"
	"fmt"
	"gitee.com/zackeus/go-zero/core/stores/redis"
	"gitee.com/zackeus/goutil"
	"strings"
	"time"
)

// RedisStore is a session store backed by redis.
//
// Besides one key per session, it keeps a per-user set of tokens so that all
// sessions of a user can be listed or removed together.
type RedisStore struct {
	client *redis.Redis // client every command is issued through
	up     string       // key prefix of the per-user token sets ("<prefix>user:")
	sp     string       // key prefix of the per-token session data ("<prefix>token:")
}

// New returns a RedisStore that keeps its data in client under prefix.
//
// A trailing ":" is appended to prefix when it is missing. Two key families
// are derived from the normalized prefix:
//   - <prefix>user:<uid>     a set holding the tokens committed for a user
//   - <prefix>token:<token>  the serialized session data of a token
//
// An error is returned when client is nil, instead of panicking on first use.
func New(prefix string, client *redis.Redis) (*RedisStore, error) {
	if client == nil {
		return nil, fmt.Errorf("redisstore: redis client is nil")
	}

	// Normalize the namespace so derived keys always read "<prefix>:...".
	if !strings.HasSuffix(prefix, ":") {
		prefix += ":"
	}

	return &RedisStore{
		client: client,
		up:     prefix + "user:",
		sp:     prefix + "token:",
	}, nil
}

// Get is GetCtx using context.Background().
//
// It returns the session data stored for token. When nothing usable is stored
// for token (missing key, expired key, or an empty value) the exists flag is
// false and the data is nil.
func (r *RedisStore) Get(token string) ([]byte, bool, error) {
	ctx := context.Background()
	return r.GetCtx(ctx, token)
}

// GetCtx fetches the session data stored under token.
//
// The boolean result reports whether data was found. An empty reply is treated
// as "no session" and yields (nil, false, nil). NOTE(review): this presumably
// covers missing keys as well, assuming the client returns "" for redis nil.
// A redis error is returned unchanged together with (nil, false).
func (r *RedisStore) GetCtx(ctx context.Context, token string) ([]byte, bool, error) {
	raw, err := r.client.GetCtx(ctx, r.sp+token)
	switch {
	case err != nil:
		return nil, false, err
	case goutil.IsEmpty(raw):
		// Nothing stored: report the session as missing.
		return nil, false, nil
	default:
		return []byte(raw), true, nil
	}
}

// Find is FindCtx using context.Background(); it returns the live sessions of
// user uid keyed by token.
func (r *RedisStore) Find(uid string) (map[string][]byte, error) {
	ctx := context.Background()
	return r.FindCtx(ctx, uid)
}

// FindCtx returns every live session of user uid, keyed by token.
//
// Tokens listed in the user's set whose session data no longer exists (for
// example because it expired) are removed from the set before returning.
// Any redis error aborts the lookup and is returned with a nil map.
func (r *RedisStore) FindCtx(ctx context.Context, uid string) (map[string][]byte, error) {
	userKey := r.up + uid

	tokens, err := r.client.SmembersCtx(ctx, userKey)
	if err != nil {
		return nil, err
	}

	found := make(map[string][]byte)
	stale := make([]any, 0)
	for i := range tokens {
		tok := tokens[i]
		data, ok, getErr := r.GetCtx(ctx, tok)
		if getErr != nil {
			return nil, getErr
		}
		if !ok {
			// Session data is gone: remember the token so it can be pruned.
			stale = append(stale, tok)
			continue
		}
		found[tok] = data
	}

	if goutil.IsEmpty(stale) {
		return found, nil
	}

	// Keep the user's token set in line with the sessions that still exist.
	if _, err = r.client.SremCtx(ctx, userKey, stale...); err != nil {
		return nil, err
	}
	return found, nil
}

// Commit is CommitCtx using context.Background().
//
// It stores b as the data of token, records token for user uid, and makes the
// session expire at expiry. Re-committing an existing token replaces its data
// and expiry time.
func (r *RedisStore) Commit(uid string, token string, b []byte, expiry time.Time) error {
	ctx := context.Background()
	return r.CommitCtx(ctx, uid, token, b, expiry)
}

// CommitCtx writes the session data and registers the token in one batch.
//
// It sends three commands through a single TxPipelinedCtx call:
//   - adds token to the token set of user uid,
//   - stores b as the session data of token,
//   - sets that session key to expire at expiry.
//
// Committing an existing token overwrites both its data and its expiry.
func (r *RedisStore) CommitCtx(ctx context.Context, uid string, token string, b []byte, expiry time.Time) error {
	userKey := r.up + uid
	sessionKey := r.sp + token

	return r.client.TxPipelinedCtx(ctx, func(p redis.Pipeliner) error {
		p.SAdd(ctx, userKey, token)
		// SET without a TTL, then pin an absolute deadline with EXPIREAT.
		p.Set(ctx, sessionKey, string(b), 0)
		p.ExpireAt(ctx, sessionKey, expiry)
		return nil
	})
}

// Delete is DeleteCtx using context.Background(); it removes the given tokens
// of user uid together with their session data.
func (r *RedisStore) Delete(uid string, tokens []string) error {
	ctx := context.Background()
	return r.DeleteCtx(ctx, uid, tokens)
}

// DeleteCtx removes the session data of every token in tokens and drops those
// tokens from the token set of user uid.
//
// Both deletions are sent through a single TxPipelinedCtx call. An empty
// tokens slice is a no-op.
func (r *RedisStore) DeleteCtx(ctx context.Context, uid string, tokens []string) error {
	if len(tokens) == 0 {
		return nil
	}

	keys := make([]string, len(tokens))
	members := make([]any, len(tokens))
	for i, tok := range tokens {
		keys[i] = r.sp + tok
		members[i] = tok
	}

	// tokens is non-empty here, so both argument lists are non-empty too.
	return r.client.TxPipelinedCtx(ctx, func(p redis.Pipeliner) error {
		p.Del(ctx, keys...)
		p.SRem(ctx, r.up+uid, members...)
		return nil
	})
}

// Clear is ClearCtx using context.Background(); it removes all sessions of
// the users listed in uids.
func (r *RedisStore) Clear(uids []string) error {
	ctx := context.Background()
	return r.ClearCtx(ctx, uids)
}

// ClearCtx deletes every session of every user in uids, along with those
// users' token sets.
//
// Token sets are read one user at a time. All deletions are then sent in a
// single TxPipelinedCtx call. An empty uids slice is a no-op.
func (r *RedisStore) ClearCtx(ctx context.Context, uids []string) error {
	if len(uids) == 0 {
		return nil
	}

	userKeys := make([]string, 0, len(uids))
	var sessionKeys []string
	for _, uid := range uids {
		userKey := r.up + uid
		tokens, err := r.client.SmembersCtx(ctx, userKey)
		if err != nil {
			return err
		}
		for _, tok := range tokens {
			sessionKeys = append(sessionKeys, r.sp+tok)
		}
		userKeys = append(userKeys, userKey)
	}

	return r.client.TxPipelinedCtx(ctx, func(p redis.Pipeliner) error {
		if len(sessionKeys) > 0 {
			p.Del(ctx, sessionKeys...)
		}
		// One entry per uid was appended above, so userKeys is never empty.
		p.Del(ctx, userKeys...)
		return nil
	})
}

// All is AllCtx using context.Background().
//
// It returns the data of every session currently stored under the store's
// session key prefix, keyed by token.
func (r *RedisStore) All() (map[string][]byte, error) {
	ctx := context.Background()
	return r.AllCtx(ctx)
}

// AllCtx returns every live session stored under this store's session key
// prefix, keyed by token.
//
// It walks the keyspace with SCAN (MATCH "<session prefix>*") and then loads
// each session with GetCtx. Keys that disappear between the SCAN and the GET
// (for example because they expired) are skipped. Any redis error aborts the
// walk and is returned with a nil map.
func (r *RedisStore) AllCtx(ctx context.Context) (map[string][]byte, error) {
	// MATCH takes a glob pattern. Escape the prefix so meta characters in a
	// caller-supplied prefix are matched literally rather than as wildcards.
	pattern := escapeRedisGlob(r.sp) + "*"

	var (
		cursor uint64
		keys   []string
	)
	for {
		page, next, err := r.client.ScanCtx(ctx, cursor, pattern, 0)
		if err != nil {
			return nil, err
		}
		keys = append(keys, page...)

		// A zero cursor means the full iteration is complete.
		cursor = next
		if cursor == 0 {
			break
		}
	}

	sessions := make(map[string][]byte, len(keys))
	for _, key := range keys {
		// Never slice a key that does not literally carry the prefix; doing so
		// could panic or produce a bogus token.
		if !strings.HasPrefix(key, r.sp) {
			continue
		}
		token := key[len(r.sp):]

		// SCAN may return the same key more than once; load it only once.
		if _, seen := sessions[token]; seen {
			continue
		}

		data, exists, err := r.GetCtx(ctx, token)
		if err != nil {
			return nil, err
		}
		if exists {
			sessions[token] = data
		}
	}
	return sessions, nil
}

// escapeRedisGlob backslash-escapes the characters that are special in redis
// glob-style patterns (*, ?, [, ], \) so that s is matched verbatim.
func escapeRedisGlob(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, c := range s {
		switch c {
		case '*', '?', '[', ']', '\\':
			b.WriteByte('\\')
		}
		b.WriteRune(c)
	}
	return b.String()
}
